We’ll go through some diagnostics using arviz.
Step one is to load some data. Rather than going through a whole modelling workflow, we’ll just take one of the example MCMC outputs that arviz provides via the function load_arviz_data.
This particular MCMC output has to do with measurements of soil radioactivity in the USA. You can read more about it here.
import arviz as az
import numpy as np
import xarray as xr
idata = az.load_arviz_data("radon")
idata
posterior
<xarray.Dataset> Size: 4MB
Dimensions: (chain: 4, draw: 500, g_coef: 2, County: 85)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* g_coef (g_coef) <U9 72B 'intercept' 'slope'
* County (County) <U17 6kB 'AITKIN' 'ANOKA' ... 'WRIGHT' 'YELLOW MEDICINE'
Data variables:
g (chain, draw, g_coef) float64 32kB ...
za_county (chain, draw, County) float64 1MB ...
b (chain, draw) float64 16kB ...
sigma_a (chain, draw) float64 16kB ...
a (chain, draw, County) float64 1MB ...
a_county (chain, draw, County) float64 1MB ...
sigma (chain, draw) float64 16kB ...
Attributes:
created_at: 2020-07-24T18:15:12.191355
arviz_version: 0.9.0
inference_library: pymc3
inference_library_version: 3.9.2
sampling_time: 18.096983432769775
    tuning_steps:               1000
posterior_predictive
<xarray.Dataset> Size: 15MB
Dimensions: (chain: 4, draw: 500, obs_id: 919)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* obs_id (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
Data variables:
y (chain, draw, obs_id) float64 15MB ...
Attributes:
created_at: 2020-07-24T18:15:12.449843
arviz_version: 0.9.0
inference_library: pymc3
    inference_library_version: 3.9.2
log_likelihood
<xarray.Dataset> Size: 15MB
Dimensions: (chain: 4, draw: 500, obs_id: 919)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* obs_id (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
Data variables:
y (chain, draw, obs_id) float64 15MB ...
Attributes:
created_at: 2020-07-24T18:15:12.448264
arviz_version: 0.9.0
inference_library: pymc3
    inference_library_version: 3.9.2
sample_stats
<xarray.Dataset> Size: 150kB
Dimensions: (chain: 4, draw: 500)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
Data variables:
step_size_bar (chain, draw) float64 16kB ...
diverging (chain, draw) bool 2kB ...
energy (chain, draw) float64 16kB ...
tree_size (chain, draw) float64 16kB ...
mean_tree_accept (chain, draw) float64 16kB ...
step_size (chain, draw) float64 16kB ...
depth (chain, draw) int64 16kB ...
energy_error (chain, draw) float64 16kB ...
lp (chain, draw) float64 16kB ...
max_energy_error (chain, draw) float64 16kB ...
Attributes:
created_at: 2020-07-24T18:15:12.197697
arviz_version: 0.9.0
inference_library: pymc3
inference_library_version: 3.9.2
sampling_time: 18.096983432769775
    tuning_steps:               1000
prior
<xarray.Dataset> Size: 1MB
Dimensions: (chain: 1, draw: 500, County: 85, g_coef: 2)
Coordinates:
* chain (chain) int64 8B 0
* draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
* County (County) <U17 6kB 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
* g_coef (g_coef) <U9 72B 'intercept' 'slope'
Data variables:
a_county (chain, draw, County) float64 340kB ...
sigma_log__ (chain, draw) float64 4kB ...
sigma_a (chain, draw) float64 4kB ...
a (chain, draw, County) float64 340kB ...
b (chain, draw) float64 4kB ...
za_county (chain, draw, County) float64 340kB ...
sigma (chain, draw) float64 4kB ...
g (chain, draw, g_coef) float64 8kB ...
sigma_a_log__ (chain, draw) float64 4kB ...
Attributes:
created_at: 2020-07-24T18:15:12.454586
arviz_version: 0.9.0
inference_library: pymc3
    inference_library_version: 3.9.2
prior_predictive
<xarray.Dataset> Size: 4MB
Dimensions: (chain: 1, draw: 500, obs_id: 919)
Coordinates:
* chain (chain) int64 8B 0
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* obs_id (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
Data variables:
y (chain, draw, obs_id) float64 4MB ...
Attributes:
created_at: 2020-07-24T18:15:12.457652
arviz_version: 0.9.0
inference_library: pymc3
    inference_library_version: 3.9.2
observed_data
<xarray.Dataset> Size: 15kB
Dimensions: (obs_id: 919)
Coordinates:
* obs_id (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
Data variables:
y (obs_id) float64 7kB ...
Attributes:
created_at: 2020-07-24T18:15:12.458415
arviz_version: 0.9.0
inference_library: pymc3
    inference_library_version: 3.9.2
constant_data
<xarray.Dataset> Size: 21kB
Dimensions: (obs_id: 919, County: 85)
Coordinates:
* obs_id (obs_id) int64 7kB 0 1 2 3 4 5 6 ... 912 913 914 915 916 917 918
* County (County) <U17 6kB 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
Data variables:
floor_idx (obs_id) int32 4kB ...
county_idx (obs_id) int32 4kB ...
uranium (County) float64 680B ...
Attributes:
created_at: 2020-07-24T18:15:12.459832
arviz_version: 0.9.0
inference_library: pymc3
    inference_library_version: 3.9.2
Arviz provides a data structure called InferenceData, which it uses for storing MCMC outputs. It’s worth getting to know it: there is some helpful explanation here.
At a high level, an InferenceData is a container for several xarray Dataset objects called ‘groups’. Each group contains xarray [DataArray](https://docs.xarray.dev/en/stable/generated/xarray.DataArray.html) objects called ‘variables’. Each variable contains a rectangular array of values, plus names for the array’s axes (‘dimensions’) and labels for the entries along each dimension (‘coordinates’).
For example, if you look at the output above you will see that the group posterior contains a variable called a_county with three dimensions called chain, draw and County. There are 85 counties, the first of which is labelled 'AITKIN'.
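To make this structure concrete, here is a minimal sketch that builds a small InferenceData by hand with az.from_dict. The values and county names are made up, but the shapes mirror the radon posterior:

```python
import arviz as az
import numpy as np

# Build a tiny InferenceData mirroring the radon posterior's structure:
# 4 chains, 500 draws and an 85-long 'County' dimension (fake values).
rng = np.random.default_rng(0)
idata_demo = az.from_dict(
    posterior={"a_county": rng.normal(size=(4, 500, 85))},
    coords={"County": [f"county_{i}" for i in range(85)]},
    dims={"a_county": ["County"]},
)
post = idata_demo.posterior

print(post["a_county"].dims)   # ('chain', 'draw', 'County')
print(post["a_county"].shape)  # (4, 500, 85)

# Coordinates let you select by label rather than position:
first_county = post["a_county"].sel(County="county_0")
print(first_county.shape)      # (4, 500)
```

Label-based selection with .sel is one of the main payoffs of storing MCMC output in xarray: you never need to remember which integer index corresponds to which county.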
The function az.summary lets us look at some useful summary statistics, including \(\hat{R}\), divergent transitions and Monte Carlo standard error (MCSE).
The variable lp, which you can find in the group sample_stats, is the model’s total log probability density. It’s not very meaningful on its own, but it is useful for judging overall convergence. The variable diverging records, for each draw, whether the transition was divergent.
az.summary(idata.sample_stats, var_names=["lp", "diverging"])
arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
arviz/stats/diagnostics.py:991: RuntimeWarning: invalid value encountered in scalar divide
  varsd = varvar / evar / 4
                mean     sd     hdi_3%    hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
lp         -1143.808  9.272  -1160.741  -1126.532      0.426    0.242     477.0     772.0   1.01
diverging      0.000  0.000      0.000      0.000      0.000      NaN    2000.0    2000.0    NaN
In this case there were no post-warmup divergent transitions, and the \(\hat{R}\) statistic for the lp variable is pretty close to 1: great! The NaN entries for diverging (and the accompanying RuntimeWarnings) appear because diverging is constant at zero, so the variances these diagnostics divide by are zero.
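\(\hat{R}\) can also be computed directly with az.rhat, which accepts anything convertible to an InferenceData. A sketch on synthetic draws (for independent, well-mixed chains like these we expect a value very close to 1):

```python
import arviz as az
import numpy as np

# Four synthetic 'chains' of independent draws: perfectly mixed by
# construction, so R-hat should be very close to 1.
rng = np.random.default_rng(1)
draws = rng.normal(size=(4, 500))  # (chain, draw)

# convert_to_dataset names an unlabelled array "x" by default.
rhat = az.rhat(az.convert_to_dataset(draws))["x"].item()
print(round(rhat, 3))
```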
Sometimes it’s useful to summarise individual parameters. This can be done by pointing az.summary at the group where the parameters of interest live. In this case the group is called posterior.
az.summary(idata.posterior, var_names=["sigma", "g"])
               mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
sigma         0.729  0.018   0.696    0.764      0.000    0.000    2228.0    1371.0    1.0
g[intercept]  1.494  0.036   1.427    1.562      0.001    0.001    1240.0    1566.0    1.0
g[slope]      0.700  0.093   0.526    0.880      0.003    0.002    1157.0    1066.0    1.0
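The numbers in such a summary can also be reproduced directly with xarray reductions and az.hdi. A sketch on a synthetic posterior (made-up draws; the radon idata would work the same way):

```python
import arviz as az
import numpy as np

# Fake 'sigma' draws standing in for a real posterior.
rng = np.random.default_rng(2)
idata_demo = az.from_dict(
    posterior={"sigma": rng.normal(0.73, 0.02, size=(4, 500))}
)
sigma = idata_demo.posterior["sigma"]

print(float(sigma.mean()))  # posterior mean, pooling chains and draws
print(float(sigma.std()))   # posterior standard deviation
print(az.hdi(idata_demo, var_names=["sigma"]))  # 94% highest-density interval
```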
Now we can start evaluating the model. First we check whether replicated measurements from the model’s posterior predictive distribution broadly agree with the observed measurements, using the arviz function plot_lm:
az.style.use("arviz-doc")
az.plot_lm(
    y=idata.observed_data["y"],
    x=idata.observed_data["obs_id"],
    y_hat=idata.posterior_predictive["y"],
    figsize=[12, 5],
    grid=False,
);
The function az.loo can quickly estimate a model’s out-of-sample log likelihood (which we saw above is a nice default loss function), allowing a numerical comparison between models.
Watch out for the warning column, which can tell you if the estimate is likely to be unreliable. It’s usually a good idea to set the pointwise argument to True, as this allows more detailed analysis at the per-observation level.
az.loo(idata, var_name="y", pointwise=True)
Computed from 2000 posterior samples and 919 observations log-likelihood matrix.
Estimate SE
elpd_loo -1027.18 28.85
p_loo 26.82 -
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.70] (good) 919 100.0%
(0.70, 1] (bad) 0 0.0%
(1, Inf) (very bad) 0 0.0%
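With pointwise=True the returned object also carries a per-observation pareto_k array that can be inspected directly. A sketch using a synthetic log-likelihood group (fake numbers, not the radon model):

```python
import arviz as az
import numpy as np

# Fake log-likelihood values: 4 chains, 500 draws, 50 observations.
rng = np.random.default_rng(3)
idata_demo = az.from_dict(
    posterior={"mu": rng.normal(size=(4, 500))},
    log_likelihood={"y": rng.normal(-1.0, 0.1, size=(4, 500, 50))},
)
loo_res = az.loo(idata_demo, var_name="y", pointwise=True)

# One Pareto k value per observation; k > 0.7 flags points where the
# importance-sampling approximation is unreliable.
print(loo_res.pareto_k.shape)               # (50,)
print(int((loo_res.pareto_k > 0.7).sum()))  # count of problematic observations
```

Observations with large k are often exactly the ones worth a closer look: they are highly influential, so the model may be misspecified for them.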
The function az.compare is useful for comparing different models’ out-of-sample log likelihood estimates.
idata.log_likelihood["fake"] = xr.DataArray(
    # generate some fake log likelihoods
    np.random.normal(0, 2, [4, 500, 919]),
    coords=idata.log_likelihood.coords,
    dims=idata.log_likelihood.dims,
)
comparison = az.compare(
    {
        "real": az.loo(idata, var_name="y"),
        "fake": az.loo(idata, var_name="fake"),
    }
)
comparison
comparison
arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
      rank     elpd_loo        p_loo   elpd_diff    weight         se        dse  warning scale
real     0 -1027.176982    26.817161    0.000000  0.984786  28.848998   0.000000    False   log
fake     1 -1800.021190  3632.470122  772.844209  0.015214   3.464644  29.082427     True   log
The function az.plot_compare shows these results on a nice graph:
az.plot_compare(comparison);